Add EAGLE-3 draft head#20149
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20149
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 3 PendingAs of commit f6c749f with merge base dc55469 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
This PR introduces an EAGLE-3 “draft head” module under examples/models/eagle3/ that can execute a single decoder-layer forward pass and load vLLM speculator-format safetensors checkpoints, along with unit tests covering forward execution, checkpoint loading (mono + sharded), vocab remapping tensors, RoPE precision behavior, and rejection of unsupported checkpoint variants.
Changes:
- Add
Eagle3Draft+ supporting modules/config, including a safetensors checkpoint loader for vLLM speculator checkpoints. - Add unit tests validating forward shapes, mapping behavior, checkpoint roundtrip (including sharded), and unsupported config variants.
- Register the new tests in
pytest.ini.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
pytest.ini |
Adds the new EAGLE-3 draft-head test file to pytest discovery. |
examples/models/eagle3/draft.py |
Implements the EAGLE-3 draft-head model and speculator safetensors checkpoint loading. |
examples/models/eagle3/test_draft.py |
Adds unit tests for forward pass, checkpoint loading, vocab mapping, and RoPE precision. |
examples/models/eagle3/__init__.py |
Declares the eagle3 examples package. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def fuse(self, aux: torch.Tensor) -> torch.Tensor: | ||
| """Fuse concatenated target aux hidden states (B,T,3*D) -> feature (B,T,D).""" | ||
| return self.fc(aux) |
| Draft ids map back to target ids with ``target_id = draft_id + d2t[draft_id]``. | ||
| Speculator checkpoints store the decoder layer under ``layers.0.*`` and may | ||
| include ``embed_tokens``, ``d2t``, and ``t2d``. |
| # d2t/t2d are index/mask tensors (their checkpoint shape differs from the | ||
| # placeholder buffers); register them directly, load the rest strict. | ||
| model.register_buffer("d2t", state_dict.pop("d2t"), persistent=False) | ||
| model.register_buffer("t2d", state_dict.pop("t2d"), persistent=False) | ||
| model.load_state_dict(state_dict, strict=True, assign=True) |
| @pytest.mark.parametrize("sharded", [False, True]) | ||
| def test_from_checkpoint_roundtrip(tmp_path, sharded): | ||
| cfg = tiny_config() | ||
| src, d2t, t2d = _write_checkpoint(str(tmp_path), cfg, sharded=sharded) | ||
|
|
||
| model, loaded_cfg = Eagle3Draft.from_checkpoint( | ||
| str(tmp_path), device="cpu", dtype=torch.float32 | ||
| ) | ||
| assert loaded_cfg.has_own_embed | ||
| assert loaded_cfg.aux_hidden_state_layers == cfg.aux_hidden_state_layers | ||
| assert loaded_cfg.target_vocab_size == cfg.target_vocab_size | ||
| torch.testing.assert_close( | ||
| model.midlayer.self_attn.q_proj.weight, src.midlayer.self_attn.q_proj.weight | ||
| ) | ||
| torch.testing.assert_close(model.fc.weight, src.fc.weight) | ||
| assert torch.equal(model.d2t, d2t) | ||
| assert torch.equal(model.t2d, t2d) | ||
| assert model.midlayer.self_attn.inv_freq.dtype == torch.float32 | ||
| T = 4 | ||
| feat = model.fuse(torch.randn(1, T, 3 * cfg.target_hidden_size)) | ||
| emb = model.embed(torch.randint(0, cfg.target_vocab_size, (T,))).unsqueeze(0) | ||
| logits, g = model(emb, feat, torch.arange(T)) | ||
| assert logits.shape == (1, T, cfg.draft_vocab_size) |
Adds the EAGLE-3 draft-head module, vLLM speculator checkpoint loading,
reduced-vocabulary id mapping, and tests for forward execution,
checkpoint loading, vocab remapping, RoPE precision, and unsupported
checkpoint variants.
Authored with assistance from Claude Code.